package com.akto.gpt.handlers.gpt_prompts;

import com.mongodb.BasicDBObject;
import com.mongodb.BasicDBList;
import javax.validation.ValidationException;
import org.json.JSONObject;
import org.json.JSONArray;

public class AnalyzeVulnerabilityPromptHandler extends PromptHandler {

    @Override
    protected void validate(BasicDBObject queryData) throws ValidationException {
        if (queryData == null) {
            throw new ValidationException("Query data cannot be null");
        }
        if (!queryData.containsField("prompt")) {
            throw new ValidationException("prompt field is required");
        }
    }

    @Override
    protected String getPrompt(BasicDBObject queryData) {
        return queryData.getString("prompt");
    }

    @Override
    protected BasicDBObject processResponse(String rawResponse) {
        BasicDBObject result = new BasicDBObject();
        BasicDBList segments = new BasicDBList();
        
        try {
            if (rawResponse == null || rawResponse.trim().isEmpty()) {
                result.put("vulnerableSegments", segments);
                return result;
            }
            
            // Extract the response field from the JSON structure (like AnalyzeRequestResponseHeaders does)
            String processedResponse = processOutput(rawResponse);
            
            if (processedResponse.equals("NOT_FOUND")) {
                result.put("vulnerableSegments", segments);
                return result;
            }
            
            // Extract JSON from the processed response text
            String vulnerabilityJson = extractJsonFromText(processedResponse);
            if (vulnerabilityJson == null) {
                result.put("vulnerableSegments", segments);
                return result;
            }
            
            // Parse the vulnerability JSON
            JSONObject vulnerabilityData = new JSONObject(vulnerabilityJson);
            segments = parseVulnerableSegments(vulnerabilityData);
            result.put("vulnerableSegments", segments);
            
        } catch (Exception e) {
            result.put("vulnerableSegments", segments);
        }
        
        return result;
    }
    
    private BasicDBList parseVulnerableSegments(JSONObject jsonResponse) {
        BasicDBList segments = new BasicDBList();
        
        try {
            if (jsonResponse == null || !jsonResponse.has("vulnerableSegments")) {
                return segments;
            }
            
            Object vulnerableSegmentsObj = jsonResponse.get("vulnerableSegments");
            
            if (vulnerableSegmentsObj instanceof JSONArray) {
                JSONArray segmentsArray = (JSONArray) vulnerableSegmentsObj;
                
                for (int i = 0; i < segmentsArray.length(); i++) {
                    try {
                        JSONObject segment = segmentsArray.getJSONObject(i);
                        BasicDBObject segmentObj = new BasicDBObject();
                        
                        // Extract start position
                        if (segment.has("start")) {
                            segmentObj.put("start", segment.getInt("start"));
                        } else {
                            segmentObj.put("start", 0);
                        }
                        
                        // Extract end position
                        if (segment.has("end")) {
                            segmentObj.put("end", segment.getInt("end"));
                        } else {
                            segmentObj.put("end", 0);
                        }
                        
                        // Extract phrase
                        if (segment.has("phrase")) {
                            segmentObj.put("phrase", segment.getString("phrase"));
                        } else {
                            segmentObj.put("phrase", "");
                        }
                        
                        segments.add(segmentObj);
                        
                    } catch (Exception e) {
                        logger.error("Error processing segment " + i + ": " + e.getMessage());
                    }
                }
            }
            
        } catch (Exception e) {
            logger.error("Error parsing vulnerable segments: " + e.getMessage(), e);
        }
        
        return segments;
    }
    
    
    /**
     * Extract JSON object from text that may contain other content
     */
    private String extractJsonFromText(String text) {
        if (text == null || text.trim().isEmpty()) {
            return null;
        }
        
        // Look for JSON patterns in the text
        // Try multiple approaches to find valid JSON
        
        // Approach 1: Find the first complete JSON object
        String jsonCandidate = extractFirstCompleteJsonObject(text);
        if (jsonCandidate != null) {
            return jsonCandidate;
        }
        
        // Approach 2: Look for vulnerableSegments specifically
        jsonCandidate = extractVulnerabilityJson(text);
        if (jsonCandidate != null) {
            return jsonCandidate;
        }
        
        logger.warn("No valid JSON found in text: " + text.substring(0, Math.min(200, text.length())));
        return null;
    }
    
    private String extractFirstCompleteJsonObject(String text) {
        int firstBrace = text.indexOf('{');
        if (firstBrace == -1) {
            return null;
        }
        
        // Find the matching closing brace
        int braceCount = 0;
        int lastBrace = -1;
        
        for (int i = firstBrace; i < text.length(); i++) {
            char c = text.charAt(i);
            if (c == '{') {
                braceCount++;
            } else if (c == '}') {
                braceCount--;
                if (braceCount == 0) {
                    lastBrace = i;
                    break;
                }
            }
        }
        
        if (lastBrace == -1) {
            return null;
        }
        
        String jsonCandidate = text.substring(firstBrace, lastBrace + 1);
        
        // Validate that it's a valid JSON
        try {
            JSONObject json = new JSONObject(jsonCandidate);
            // Check if it has vulnerableSegments field
            if (json.has("vulnerableSegments")) {
                return jsonCandidate;
            }
        } catch (Exception e) {
            logger.debug("Invalid JSON candidate: " + jsonCandidate);
        }
        
        return null;
    }
    
    private String extractVulnerabilityJson(String text) {
        // Look for patterns like "vulnerableSegments" in the text
        String lowerText = text.toLowerCase();
        int vulnerableSegmentsIndex = lowerText.indexOf("vulnerablesegments");
        
        if (vulnerableSegmentsIndex == -1) {
            return null;
        }
        
        // Find the opening brace before vulnerableSegments
        int startIndex = text.lastIndexOf('{', vulnerableSegmentsIndex);
        if (startIndex == -1) {
            return null;
        }
        
        // Find the matching closing brace
        int braceCount = 0;
        int endIndex = -1;
        
        for (int i = startIndex; i < text.length(); i++) {
            char c = text.charAt(i);
            if (c == '{') {
                braceCount++;
            } else if (c == '}') {
                braceCount--;
                if (braceCount == 0) {
                    endIndex = i;
                    break;
                }
            }
        }
        
        if (endIndex == -1) {
            return null;
        }
        
        String jsonCandidate = text.substring(startIndex, endIndex + 1);
        
        // Validate that it's a valid JSON
        try {
            JSONObject json = new JSONObject(jsonCandidate);
            if (json.has("vulnerableSegments")) {
                return jsonCandidate;
            }
        } catch (Exception e) {
            logger.debug("Invalid vulnerability JSON candidate: " + jsonCandidate);
        }
        
        return null;
    }
}
